use crate::gaussian_noise::conversion::modular_variance_to_variance;
use crate::utils::square;

/// Total noise variance of the multi-bit external product on a GLWE ciphertext.
///
/// The result is the sum of two contributions computed by private helpers of this file:
/// the theoretical variance, which depends on `variance_ggsw`, and the extra variance
/// introduced by the fft, which depends on `fft_precision` and `jit_fft`. All the other
/// parameters are forwarded unchanged to both helpers.
///
/// # Panics
///
/// Panics when the fft noise helper rejects its input: `glwe_dimension` outside `1..=6`,
/// or, when `jit_fft` is `false`, a `grouping_factor` with no known fft scaling weight.
// The formula genuinely needs every one of these parameters.
#[allow(clippy::too_many_arguments)]
pub fn variance_multi_bit_external_product_glwe(
    glwe_dimension: u64,
    polynomial_size: u64,
    log2_base: u64,
    level: u64,
    ciphertext_modulus_log: u32,
    fft_precision: u32,
    variance_ggsw: f64,
    grouping_factor: u32,
    jit_fft: bool,
) -> f64 {
    let theoretical_part = theoretical_variance_multi_bit_external_product_glwe(
        glwe_dimension,
        polynomial_size,
        log2_base,
        level,
        ciphertext_modulus_log,
        variance_ggsw,
        grouping_factor,
    );

    let fft_part = fft_noise_variance_multi_bit_external_product_glwe(
        glwe_dimension,
        polynomial_size,
        log2_base,
        level,
        ciphertext_modulus_log,
        fft_precision,
        grouping_factor,
        jit_fft,
    );

    theoretical_part + fft_part
}

/// Noise variance of the multi-bit external product, without the fft contribution.
///
/// The result is the sum of two terms:
/// - a term proportional to `variance_ggsw`, scaled by the decomposition level count, the
///   squared decomposition base and `2^grouping_factor`;
/// - a term that depends only on the parameters and on the first two moments of a secret
///   key coefficient (taken here as `1/4` and `(1/2)^2`).
///
/// `log2_base` is the base-2 logarithm of the decomposition base and
/// `ciphertext_modulus_log` the base-2 logarithm of the ciphertext modulus.
fn theoretical_variance_multi_bit_external_product_glwe(
    glwe_dimension: u64,
    polynomial_size: u64,
    log2_base: u64,
    level: u64,
    ciphertext_modulus_log: u32,
    variance_ggsw: f64,
    grouping_factor: u32,
) -> f64 {
    // Moments of a key coefficient, converted with the crate's modular-variance helper.
    let key_variance = modular_variance_to_variance(1. / 4., ciphertext_modulus_log);
    let key_square_mean = modular_variance_to_variance(square(1. / 2.), ciphertext_modulus_log);

    let glwe_dim = glwe_dimension as f64;
    let base = f64::powi(2., log2_base as i32);
    // base^(2 * level)
    let base_pow_2l = f64::powi(2., (log2_base * 2 * level) as i32);
    let level_count = level as f64;
    let poly_size = polynomial_size as f64;
    let modulus_squared = f64::powi(2., 2 * ciphertext_modulus_log as i32);

    // Shared sub-expressions of the formula below.
    let group_scale = f64::powi(2., grouping_factor as i32);
    let mask_coefficient_count = glwe_dim * poly_size;

    // Contribution driven by the ggsw noise.
    let ggsw_term = level_count * (glwe_dim + 1.) * poly_size * (square(base) + 2.) / 12.
        * variance_ggsw
        * group_scale;

    // Contribution independent of the ggsw noise, built from three summands.
    let unit_variance = modular_variance_to_variance(1., ciphertext_modulus_log);
    let key_second_moment = key_variance + key_square_mean;
    let truncation_summand = (modulus_squared - base_pow_2l) / (24. * base_pow_2l)
        * (unit_variance + mask_coefficient_count * key_second_moment);
    let key_variance_summand = mask_coefficient_count / 8. * key_variance;
    let key_mean_summand = 1. / 16. * square(1. - mask_coefficient_count) * key_square_mean;
    let key_term = truncation_summand + key_variance_summand + key_mean_summand;

    ggsw_term + key_term
}

/// Fft noise scaling weights, stored as `(grouping_factor, weight)` pairs.
///
/// When the jit fft is not used, the fft noise formula of this file looks up the weight of
/// the requested grouping factor and uses it as a base-2 exponent (through `f64::exp2`).
/// The lookup is a binary search on the grouping factor, so the entries must stay sorted by
/// their first component. Grouping factors absent from this table make the lookup panic.
// NOTE(review): the numeric values are presumably fitted from measurements — confirm source.
const FFT_SCALING_WEIGHTS: [(u32, f64); 3] = [
    (2, 0.265_753_885_551_084_5),
    (3, 1.350_324_550_016_489_8),
    (4, 2.475_036_769_207_096),
];
/// Scaling weight (base-2 exponent) used instead of the table above when the jit fft is
/// enabled; it applies to every grouping factor.
const JIT_FFT_SCALING_WEIGHT: f64 = -2.015_541_494_298_571_7;

/// Additional noise generated by fft computation
///
/// The modular variance is the product of:
/// - `2^w`, with `w` the fft scaling weight (`JIT_FFT_SCALING_WEIGHT` when `jit_fft` is set,
///   otherwise the entry of `FFT_SCALING_WEIGHTS` matching `grouping_factor`);
/// - `2^(2 * (ciphertext_modulus_log - fft_precision))`;
/// - `level * base^2 * polynomial_size^2 * (glwe_dimension + 1)`, with `base = 2^log2_base`.
///
/// It is then converted with `modular_variance_to_variance`.
///
/// # Panics
///
/// Panics if `glwe_dimension` is not in `1..=6`, or if `jit_fft` is `false` and no fft
/// scaling weight is known for `grouping_factor`.
// The formula genuinely needs every one of these parameters.
#[allow(clippy::too_many_arguments)]
fn fft_noise_variance_multi_bit_external_product_glwe(
    glwe_dimension: u64,
    polynomial_size: u64,
    log2_base: u64,
    level: u64,
    ciphertext_modulus_log: u32,
    fft_precision: u32,
    grouping_factor: u32,
    jit_fft: bool,
) -> f64 {
    // Guard first: only glwe dimensions from 1 to 6 are accepted.
    let k = glwe_dimension;
    assert!((1..7).contains(&k), "k = {k}");

    let fft_scaling_weight = if jit_fft {
        JIT_FFT_SCALING_WEIGHT
    } else {
        FFT_SCALING_WEIGHTS
            .iter()
            .find(|&&(factor, _)| factor == grouping_factor)
            .map(|&(_, weight)| weight)
            .unwrap_or_else(|| {
                panic!("Could not find fft scaling weight for grouping factor {grouping_factor}.")
            })
    };

    let base = f64::powi(2., log2_base as i32);
    let level_count = level as f64;
    let poly_size = polynomial_size as f64;

    // Bits of the ciphertext modulus that the fft precision does not cover; the variance
    // scales with the square of the corresponding power of two.
    let lost_bits = ciphertext_modulus_log as i32 - fft_precision as i32;
    let precision_penalty = f64::powi(2., 2 * lost_bits);

    let weight_factor = fft_scaling_weight.exp2();
    let modular_variance = weight_factor
        * precision_penalty
        * level_count
        * base
        * base
        * poly_size.powi(2)
        * (k as f64 + 1.);

    modular_variance_to_variance(modular_variance, ciphertext_modulus_log)
}
